Tensorflow学习9-3:谷歌inception-v3模型之transfer learning retrain

因为是 transfer learning 操作,所以直到网络的bottleneck部分之前都不需要改变参数和训练。只需要传入图片数据到网络中计算得到结果。再拿到这个结果到后面的全连接层进行训练。

所以训练的内容不多,迭代200次左右就可以达到要求。

1 准备模型文件

.pb模型文件下载地址:
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

解压后得到如下内容:

2 准备retrain训练工程目录

1、在D:/Tensorflow目录下新建retrain文件夹。
在里面新建以下文件夹:
bottleneck:存放瓶颈部分输出的数据,用于全连接层的训练
data
└ train:存放用于训练的用文件夹分类好的图片
images:存放用于测试的单个图片

2、把image_retraining中的retrain.py文件拖过来(注意:这个文件不能用最新版的)

3、新建retrain.bat文件,内容如下

1
2
3
4
5
6
7
8
python retrain.py ^
--bottleneck_dir bottleneck ^
--how_many_training_steps 100 ^
--model_dir D:/Tensorflow/models/inception/ ^
--output_graph output_graph.pb ^
--output_labels output_labels.txt ^
--image_dir data/train/
pause

文件架构如下图:

3 执行retrain.bat脚本,进行transfer learning

跌倒200次训练完成后,会在当前目录生成output_graph.pboutput_labels.txt两个文件,至此训练完成。可以用这个pb模型文件来测试分类了。

4 测试分类效果

自己写一个测试代码,用自己的图片测试下分类效果:

1
2
3
4
5
6
7
8
9
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt

TEST_IMG_DIR = "D:/Tensorflow/Test Images/"
RETRAIN_DIR = "D:/Tensorflow/retrain/" #模型存放目录
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
lines = tf.gfile.GFile(RETRAIN_DIR + "output_labels.txt").readlines()
uid_to_human = {}

for uid,line in enumerate(lines):
line=line.strip("\n")
uid_to_human[uid] = line

#print(uid_to_human)

def id_to_string(node_id):
if node_id not in uid_to_human:
print("node_id not in uid_to_human")
return ""
return uid_to_human[node_id]

#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile(RETRAIN_DIR + "output_graph.pb", "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")

with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name("final_result:0")
#遍历用于测试的图片目录
for root,dirs,files in os.walk(RETRAIN_DIR + "images/"):
for file in files:
#载入图片
image_data = tf.gfile.FastGFile(os.path.join(root,file), "rb").read()
predictions = sess.run(softmax_tensor, {"DecodeJpeg/contents:0" : image_data}) #jpg格式图片
#predictions = sess.run(softmax_tensor, {"DecodeJPGInput:0" : image_data}) #jpg格式图片
predictions = np.squeeze(predictions) #吧结果转为1维数据

#打印图片路径及名称
image_path = os.path.join(root,file)
print(image_path)
#显示图片
img = Image.open(image_path)
plt.imshow(img)
plt.axis("off")
plt.show()

#排序
top_k = predictions.argsort()[::-1]
for node_id in top_k:
#获取分类名称
human_string = id_to_string(node_id)
#获取该分类的概率
score = predictions[node_id]
print("%s (score = %.5f)" % (human_string, score))
print()

D:/Tensorflow/retrain/images/111.jpg

png

pet (score = 0.80940)
flower (score = 0.19060)

D:/Tensorflow/retrain/images/222.jpg

png

flower (score = 0.99097)
pet (score = 0.00903)

꧁༺The༒End༻꧂